
Refactor pipeline_parallel.py for graph PP reuse#2724

Draft
sanketpurandare wants to merge 1 commit into sanketpurandare/stack/2 from sanketpurandare/stack/3

Conversation

@sanketpurandare
Contributor

@sanketpurandare sanketpurandare commented Mar 27, 2026

@meta-cla meta-cla bot added the CLA Signed label (managed by the Meta Open Source bot) Mar 27, 2026
@sanketpurandare sanketpurandare force-pushed the sanketpurandare/stack/3 branch from 2d2fb54 to bd097f8 Compare March 27, 2026 00:49
@sanketpurandare sanketpurandare changed the base branch from sanketpurandare/stack/2 to main March 27, 2026 01:13
@sanketpurandare sanketpurandare force-pushed the sanketpurandare/stack/3 branch from bd097f8 to d3d14c2 Compare March 27, 2026 01:14
@sanketpurandare sanketpurandare changed the base branch from main to sanketpurandare/stack/2 March 27, 2026 01:14
Contributor

@tianyu-l tianyu-l left a comment


Please add a more detailed description of why these changes are being made.

Comment on lines 40 to 48
__all__ = [
"pipeline_llm",
"get_pipeline_metadata",
"build_pipeline_schedule",
"generate_llm_fqn_per_model_part",
"pipeline_module_split",
"split_module",
"get_pp_rank_to_stage_indices_mapping",
]
Contributor


do all these fields need to be public and exposed?

Contributor Author


Sorry, this should have been marked as a draft PR; it is not ready for review yet.

@tianyu-l tianyu-l requested a review from acisseJZhong March 27, 2026 02:17
@tianyu-l
Contributor

@claude review this PR

@claude

claude bot commented Mar 27, 2026

Claude finished @tianyu-l's task in 3m 25s


Review of PR #2724: Refactor pipeline_parallel.py for graph PP reuse

  • Gather context and read the diff
  • Read the full modified file and base version
  • Analyze changes and write review

Summary

This PR refactors pipeline_parallel.py by extracting three nested/inline functions into top-level, reusable public functions:

  1. get_pipeline_metadata — extracted from the top of pipeline_llm, computes num_virtual_stages, num_layers, input_weight, output_weight
  2. split_module — extracted from _build_stage_from_modules (nested inside pipeline_module_split), deep-copies and prunes the model to keep only specified modules
  3. get_pp_rank_to_stage_indices_mapping — extracted from _get_stage_indices (nested inside pipeline_module_split), maps PP rank to stage indices for loop/v-style schedules

Additionally, build_pipeline_schedule gains two new parameters: backward_requires_autograd and scale_grads (previously hardcoded to True and False respectively).

All three new functions and new __all__ exports are presumably intended for reuse by the graph PP experiments (PRs #2725–#2727 in this stack).
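For context, the rank-to-stage assignment that loop-style schedules typically use can be illustrated with a minimal sketch. This is a hypothetical illustration by the reviewer — the function name, validation, and error message are invented and are not the torchtitan implementation:

```python
def loop_style_stage_indices(
    pp_rank: int, pp_degree: int, num_stages: int
) -> tuple[int, ...]:
    """Illustrative interleaved ("loop") mapping: rank r owns stages
    r, r + pp_degree, r + 2 * pp_degree, ... (not torchtitan's code)."""
    if num_stages % pp_degree != 0:
        raise ValueError(
            f"num_stages {num_stages} must be evenly divisible by pp_degree {pp_degree}"
        )
    return tuple(range(pp_rank, num_stages, pp_degree))


# e.g. with pp_degree=4 and num_stages=8, rank 1 owns stages (1, 5)
```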


Issues and Suggestions

1. pp_degree parameter lacks type annotation (get_pp_rank_to_stage_indices_mapping)

pipeline_parallel.py:466

def get_pp_rank_to_stage_indices_mapping(
    pp_rank: int,
    pp_degree,          # <-- missing type annotation
    pp_schedule: str,
    num_stages: int,
) -> tuple[int, ...]:

pp_degree should be annotated as int. Every other parameter in this function is typed.
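A sketch of the corrected signature (the `...` body is a stub; only the annotation changes):

```python
def get_pp_rank_to_stage_indices_mapping(
    pp_rank: int,
    pp_degree: int,  # annotation added
    pp_schedule: str,
    num_stages: int,
) -> tuple[int, ...]:
    ...
```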



2. assert used for user-facing validation in get_pp_rank_to_stage_indices_mapping

pipeline_parallel.py:486-488 and pipeline_parallel.py:493-495

Per the repo's code guidelines (CLAUDE.md): "assert only for internal invariants that indicate programmer error" and "ValueError for user-facing errors (bad config, invalid input)."

The two assert statements validate num_stages % pp_degree and stages_per_rank == 2. These were originally inside a private nested function (_get_stage_indices), where assert was arguably acceptable. Now that get_pp_rank_to_stage_indices_mapping is a public API, they should raise ValueError, consistent with how pipeline_llm validates similar conditions (e.g. pipeline_parallel.py:173-180).

# Current (assert):
assert (
    num_stages % pp_degree == 0
), f"num_stages {num_stages} must be evenly divisible by pp_degree {pp_degree}"

# Should be:
if num_stages % pp_degree != 0:
    raise ValueError(
        f"num_stages {num_stages} must be evenly divisible by pp_degree {pp_degree}"
    )

The same applies to the stages_per_rank == 2 assertion on line 493.
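For that second assertion, the conversion would follow the same pattern. Sketched here as a standalone helper for illustration — the helper name and the exact error message in the source are assumptions:

```python
def validate_stages_per_rank(stages_per_rank: int) -> None:
    # Hypothetical helper mirroring the suggested ValueError pattern
    if stages_per_rank != 2:
        raise ValueError(
            f"V-style schedules require exactly 2 stages per rank, got {stages_per_rank}"
        )
```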



3. split_module broadens isinstance check vs. original without explanation

pipeline_parallel.py:425-427

The original _build_stage_from_modules checked:

isinstance(module_value, (nn.ModuleDict, nn.ModuleList))

The new split_module checks:

isinstance(module_value, (nn.ModuleDict, nn.ModuleList, ModuleDict, ModuleList))

Where ModuleDict and ModuleList are from torchtitan.protocols.module. This is a behavioral change compared to the original code path. If this is intentional (to support torchtitan's custom protocol types), it should be called out in the PR description. If it's accidental, it should be reverted to match the original behavior.


4. build_pipeline_schedule docstring not updated for new parameters

pipeline_parallel.py:216-226

The docstring still only documents the original four parameters. The new backward_requires_autograd and scale_grads parameters are not documented:

    """Builds a pipeline schedule for the given job configuration and stages.

    Args:
        parallelism (ParallelismConfig): The parallelism configuration.
        local_batch_size (int): The local batch size for computing microbatches.
        stages (list[PipelineStage]): The stages to be scheduled.
        loss_fn (Callable): The loss function.
        # Missing: backward_requires_autograd, scale_grads

At minimum, add brief descriptions for both new parameters so downstream consumers understand what they control.
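A possible wording, shown here as a stub (the stub name is invented; the descriptions are inferred from the parameter names and the defaults stated in the summary above, so the actual semantics should be confirmed against the code):

```python
def build_pipeline_schedule_docstring_stub() -> None:
    """Builds a pipeline schedule for the given job configuration and stages.

    Args:
        parallelism (ParallelismConfig): The parallelism configuration.
        local_batch_size (int): The local batch size for computing microbatches.
        stages (list[PipelineStage]): The stages to be scheduled.
        loss_fn (Callable): The loss function.
        backward_requires_autograd (bool): Whether the backward pass relies on
            autograd (default: True).
        scale_grads (bool): Whether the schedule scales gradients by the number
            of microbatches (default: False).
    """
```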



5. Agree with reviewer: consider whether all new exports need to be public

pipeline_parallel.py:40-48

Three new symbols are added to __all__: get_pipeline_metadata, split_module, get_pp_rank_to_stage_indices_mapping. As @tianyu-l noted, it's worth considering whether all of these truly need to be part of the public API.

split_module is the strongest candidate for public exposure — it's a general-purpose utility that could be independently useful. get_pp_rank_to_stage_indices_mapping also has a clear independent use case.

get_pipeline_metadata is more debatable — it returns a fairly specific tuple (num_virtual_stages, num_layers, input_weight, output_weight) that's tightly coupled to the pipeline_llm flow. If the graph PP experiment only needs to call it once in a similar orchestration function, consider whether the caller should just inline the logic or whether this really warrants a public function. Without seeing the consuming PRs (#2725–#2727), it's hard to judge definitively.


6. Blank line between import groups removed

pipeline_parallel.py:24-25

The PR removes the blank line between the torch.distributed.pipelining.schedules imports and the torchtitan imports. This is a minor PEP 8 style regression — standard convention separates third-party/stdlib imports from local imports with a blank line. The original code had this correctly.


7. PR description needs more detail

Agreeing with the existing review comment: the PR description should explain why these functions are being extracted and what the intended reuse pattern is. Linking to or describing the consuming code in the stacked PRs (#2725–#2727) would help reviewers understand the motivation.


Correctness Assessment

The refactoring itself appears mechanically correct — the extracted functions preserve the same logic as the original inline/nested versions. The main pipeline_llm function's behavior should be identical before and after this change. The new backward_requires_autograd and scale_grads parameters default to the same values as the original hardcoded ones (True and False), so existing callers are unaffected.

One thing to verify: that the split_module isinstance broadening (issue #3 above) is indeed the only behavioral difference that could affect existing code paths.


@sanketpurandare sanketpurandare marked this pull request as draft March 27, 2026 02:21
stack-info: PR: #2724, branch: sanketpurandare/stack/3
@sanketpurandare sanketpurandare changed the base branch from sanketpurandare/stack/2 to main March 27, 2026 18:57
@sanketpurandare sanketpurandare force-pushed the sanketpurandare/stack/3 branch from d3d14c2 to 2fb3b75 Compare March 27, 2026 18:57
@sanketpurandare sanketpurandare changed the base branch from main to sanketpurandare/stack/2 March 27, 2026 18:57

Labels

ciflow/8gpu, CLA Signed

2 participants